# GBA Song Editor
#
# Copyright 2010-2011 Karl A. Knechtel.
#
# Miscellaneous implementation details.
#
# Licensed under the Generic Non-Commercial Copyleft Software License,
# Version 1.1 (hereafter "Licence"). You may not use this file except
# in the ways outlined in the Licence, which you should have received
# along with this file.
#
# Unless required by applicable law or agreed to in writing, software 
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.


from struct import pack, unpack
from array import array
from cli import PEBKAC, display


# Misc conversions.
GBA_ROM_OFFSET = 0x08000000


def as_offset(pointer):
	"""Convert a GBA ROM pointer into an offset within the ROM image.

	A null pointer (0) converts to None."""
	return None if pointer == 0 else pointer - GBA_ROM_OFFSET


def as_pointer(offset):
	"""Convert an offset within the ROM image into a GBA ROM pointer.

	None converts back to the null pointer 0."""
	return 0 if offset is None else offset + GBA_ROM_OFFSET


class Chunk(array):
	"""A mutable run of unsigned bytes (array typecode 'B') with helpers for
	reading and writing little-endian 32-bit integers."""

	def __new__(cls, data):
		# Use cls (not a hard-coded Chunk) so subclasses construct themselves.
		return array.__new__(cls, 'B', data)


	# Keep simple slices (routed through __getslice__ on Python 2) typed as Chunk.
	def __getslice__(self, x, y): return Chunk(array.__getslice__(self, x, y))


	def read_int(self, location):
		"""Return the unsigned little-endian 32-bit integer at *location*."""
		return sum(self[location + i] << (i * 8) for i in range(4))


	def write_int(self, location, value):
		"""Store *value* as a little-endian 32-bit integer at *location*.

		Only the low 32 bits are kept. Writing at len(self) appends the bytes."""
		self[location:location + 4] = array(
			'B', ((value >> (i * 8)) & 0xFF for i in range(4))
		)


	def __str__(self): return self.tostring()


def offset(chunk, location, delta):
	"""Add *delta*, in place, to the 32-bit integer stored at *location* of *chunk*."""
	current = chunk.read_int(location)
	chunk.write_int(location, current + delta)


class BadTrack(Exception):
	"""Raised when a track cannot be read, for example because it uses an
	unsupported instrument. rip_tracks treats this as a non-fatal error and
	collects it instead of aborting."""


def save_chunk(filename, chunk):
	"""Write the raw bytes of *chunk* to *filename*, replacing any existing file.

	An IOError (bad path, missing permissions, ...) is re-raised as PEBKAC
	carrying the IOError's message."""
	try:
		# open() is the documented way to create file objects; file() is deprecated.
		with open(filename, 'wb') as f:
			chunk.tofile(f)
	except IOError as ioe:
		raise PEBKAC(str(ioe))


class State(object):
	"""Editor session state: the loaded ROM and where it came from, plus the
	song currently being worked on."""
	__slots__ = ['ROM', 'ROM_file', 'song_data', 'song_source']

	def __init__(self):
		# A fresh session has neither a ROM nor a song loaded.
		self.reset_ROM()
		self.reset_song()


	def reset_ROM(self):
		"""Forget the loaded ROM (a Chunk once loaded) and its file."""
		self.ROM, self.ROM_file = None, None


	def reset_song(self):
		"""Forget the current song data and where it came from."""
		self.song_data, self.song_source = None, None


def next_byte(source, position):
	"""Return the byte at *position* together with the position after it."""
	value = source[position]
	return value, position + 1


def offset_track_pointers(data, destination, start, end):
	"""Make the 0xB2/0xB3 operands in data[start:end] absolute, in place.

	process_track stores each 0xB2/0xB3 operand relative to the start of its
	own track. That track will sit at ROM offset destination + start, so the
	track's GBA pointer is added to every such operand.

	The walk steps over the single data byte of the one-argument commands
	that process_track copies (0xBB, 0xBC, 0xBD, 0xBE, 0xBF, 0xC0, 0xC1).
	Otherwise an argument that happens to equal 0xB2/0xB3 would be mistaken
	for a pointer command, and the 4 bytes after it would be corrupted."""
	track_pointer = as_pointer(start + destination)
	position = start
	while position < end:
		command = data[position]
		if command in (0xB2, 0xB3):
			offset(data, position + 1, track_pointer)
			position += 5 # the command byte plus its 4-byte operand
		elif command in (0xBB, 0xBC, 0xBD, 0xBE, 0xBF, 0xC0, 0xC1):
			position += 2 # the command byte plus its single data byte
		else:
			position += 1
		# Don't need to consider instrument map values this time;
		# they have already been translated fully.


def prepare_for_burning(data, destination):
	"""Return a copy of song *data* with its pointers relocated for ROM offset
	*destination*.

	The header's instrument-map and track pointers, the 0xB2/0xB3 operands in
	the track data, and the pointer field of each instrument-map entry are
	turned from song-relative values into absolute GBA pointers. *data* itself
	is not modified, so the same song can be burned again elsewhere."""
	relocated = data[:]

	# Header layout: byte 0 = track count, byte 3 = instrument count.
	track_count = relocated[0]
	instrument_count = relocated[3]
	relocated[2] = 0x0A # Do these values matter?
	relocated[3] = 0x80 # This is just copying what the stock ROM does...

	track_starts = [relocated.read_int(8 + 4 * n) for n in xrange(track_count)]
	song_pointer = as_pointer(destination)

	# HAX: a zero instrument count means the song was converted from a MIDI,
	# so its instrument map pointer is already absolute and must stay as is.
	if instrument_count == 0:
		instrument_start = len(relocated)
	else:
		instrument_start = relocated.read_int(4)
		offset(relocated, 4, song_pointer)

	for n in xrange(track_count):
		offset(relocated, 8 + 4 * n, song_pointer)

	# Each track runs up to the next one; the last runs up to the instrument map.
	track_ends = track_starts[1:] + [instrument_start]
	for track_start, track_end in zip(track_starts, track_ends):
		offset_track_pointers(relocated, destination, track_start, track_end)

	# Entry 0 is the percussion placeholder: its zero pointer becomes the
	# address of the instrument map itself, which doubles as percussion map.
	# Every other entry's pointer is relative to the sample data that
	# follows the map's 12-byte entries.
	map_pointer = as_pointer(destination + instrument_start)
	samples_pointer = as_pointer(
		destination + instrument_start + 12 * instrument_count
	)
	for n in range(instrument_count):
		delta = map_pointer if n == 0 else samples_pointer
		offset(relocated, instrument_start + 12 * n + 4, delta)

	return relocated


def do_burning(target, handle, data, where):
	"""Point the 32-bit slot at offset *handle* of *target* at offset *where*,
	then copy *data*, relocated for that offset, into *target* at *where*."""
	target.write_int(handle, as_pointer(where))
	end = where + len(data)
	target[where:end] = prepare_for_burning(data, where)


def rip_tracks(source, song, instrument_map, track_count):
	"""Extract and translate the tracks of the song header at offset *song*.

	Track pointers are read from offset song + 8 onwards, one 32-bit value
	per track. Returns (track_data, track_starts, bad_tracks):
	- track_data is the concatenated track bytes, zero-padded to a multiple of 4;
	- track_starts is the offset of each successfully ripped track in track_data;
	- bad_tracks is the BadTrack errors raised for the tracks that were skipped."""
	bad_tracks = []
	track_starts = []
	track_data = Chunk('')
	for index in xrange(track_count):
		try:
			track_pointer = source.read_int(song + 8 + 4 * index)
			ripped = process_track(source, track_pointer, instrument_map)
		except BadTrack as problem:
			bad_tracks.append(problem)
		else:
			track_starts.append(len(track_data))
			track_data += ripped

	track_data.extend([0] * (-len(track_data) % 4))
	return track_data, track_starts, bad_tracks


def process_song(track_data, track_starts, instrument_map):
	"""Assemble a relocatable song: header, then tracks, then the instrument
	map and its samples. All pointers in the header are relative to the start
	of the returned Chunk.

	The stock data puts the tracks before the metadata, and the instrument
	maps and samples elsewhere; this layout is used instead for convenience."""
	track_count = len(track_starts)
	# Header: 4 bytes of counts, the map pointer, then one pointer per track.
	header_size = 4 * (2 + track_count)

	song_data = Chunk((track_count, 0, 0, instrument_map.count()))
	# The instrument map follows the header and the track data.
	song_data.write_int(4, header_size + len(track_data))
	for track_start in track_starts:
		# Writing at the current end appends the pointer.
		song_data.write_int(len(song_data), header_size + track_start)

	# In-place += keeps song_data a Chunk rather than a plain array.array.
	song_data += track_data
	song_data += instrument_map.as_bytes()
	return song_data


class InstrumentMap(object):
	"""Maps non-percussion instruments to sequential values 1..n,
	and extracts the corresponding samples. If percussion is
	encountered, all percussion maps are re-assigned ID 0,
	and the mapped percussion samples are assigned sequential
	values in the same mapping."""


	def __init__(self, ROM, instrument_map_pointer):
		# Entry 0 is a placeholder (first byte 0x80, zero pointer) standing in
		# for percussion; every translated instrument is appended after it.
		self.instrument_codes = Chunk('\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')
		# Cache: original id, or (percussion id, key) tuple -> new id.
		self.instrument_mapping = {}
		# Samples copied from the ROM, each zero-padded to a multiple of 4 bytes.
		self.sample_data = Chunk('')
		self.ROM = ROM
		self.instrument_map_offset = as_offset(instrument_map_pointer)


	def count(self):
		"""Return the number of 12-byte entries allocated, placeholder included."""
		bytecount = len(self.instrument_codes)
		assert bytecount % 12 == 0
		result = bytecount // 12 # an entry count: floor division, not a ratio
		assert result < 0x80 # allocated too many instruments?
		return result


	def get_instrument(self, index, base = None):
		"""Return a copy of the 12-byte entry *index* of the instrument map at
		ROM offset *base* (default: this song's own instrument map)."""
		if base is None: base = self.instrument_map_offset
		start = base + 12 * index
		return self.ROM[start:start + 12]


	def translate(self, original_id):
		"""Return the new id for instrument *original_id*, allocating it on first use.

		Percussion entries (first bytes 0x80 0x00) all map to 0. Unsupported
		instrument types raise BadTrack."""
		if original_id in self.instrument_mapping:
			return self.instrument_mapping[original_id]

		instrument_code = self.get_instrument(original_id)
		if instrument_code[0] == 0x80 and instrument_code[1] == 0:
			self.instrument_mapping[original_id] = 0
			return 0

		return self._prepare(instrument_code, original_id)


	def translate_percussion(self, original_id, pitch):
		"""Return the new id for key *pitch* of percussion instrument *original_id*.

		The percussion entry's pointer field (offset 4) is followed to its own
		map, and entry *pitch* of that map is translated."""
		key = (original_id, pitch)
		if key in self.instrument_mapping:
			return self.instrument_mapping[key]

		percussion_map = as_offset(self.get_instrument(original_id).read_int(4))
		instrument_code = self.get_instrument(pitch, percussion_map)

		return self._prepare(instrument_code, key)


	def _prepare(self, instrument_code, key):
		"""Allocate a new id for the 12-byte entry *instrument_code* under *key*.

		Only entries whose first byte is 0x00 or 0x08 and whose second byte is
		0x3C are supported; anything else raises BadTrack. The entry's sample
		is copied into sample_data, and the entry is appended with its pointer
		field rewritten as an offset into sample_data. Returns the new id."""
		if instrument_code[0] not in (0x00, 0x08) or instrument_code[1] != 0x3C:
			raise BadTrack('%02X %02X' % (instrument_code[0], instrument_code[1]))

		# Look up the sample (16-byte header plus the length stored at header
		# offset 12), read it and add padding.
		sample_location = as_offset(instrument_code.read_int(4))
		sample_length = self.ROM.read_int(sample_location + 12)
		# NOTE(review): write_samples reads read_int(+12) + 1 data bytes for the
		# same header; this copies one byte fewer. Confirm which is right.
		current_sample = self.ROM[sample_location:sample_location + 16 + sample_length]
		while len(current_sample) % 4: current_sample.append(0)

		# Fix instrument pointer to be an offset relative to start of sample data.
		# The pointer for the first instrument - which is the percussion map - is
		# of course relative to the start of instrument data, being zero. The
		# burn procedure is aware of this.
		instrument_code.write_int(4, len(self.sample_data))
		self.sample_data += current_sample

		result = self.count()
		self.instrument_mapping[key] = result
		self.instrument_codes += instrument_code
		return result


	def as_bytes(self):
		"""Return the instrument entries followed by the sample data."""
		return self.instrument_codes + self.sample_data


def process_track(ROM, start_pointer, instrument_map):
	"""Copy one track out of *ROM* and translate it for a standalone song.

	Bytes are copied from *start_pointer* up to and including the terminating
	0xB1 byte, with these changes:
	- the 4-byte operand of 0xB2/0xB3 is made relative to *start_pointer*;
	- the argument of 0xBD is translated through instrument_map.translate;
	- while a percussion instrument is selected, each byte below 0x80 that is
	  not a command argument is treated as a note key and replaced by
	  instrument_map.translate_percussion; the low bytes that follow it are
	  copied unchanged.

	Raises PEBKAC if *start_pointer* is null or outside the ROM. BadTrack
	from *instrument_map* propagates to the caller."""
	data = Chunk('')
	position = as_offset(start_pointer)
	# as_offset returns None for a null pointer; reject that explicitly.
	if position is None or position < 0 or position >= len(ROM):
		raise PEBKAC("Found an invalid track pointer. Song offset is probably invalid. ")

	# Original id of the selected percussion instrument, or None while a
	# regular instrument (or no instrument) is selected.
	percussion = None
	while True:
		byte = ROM[position]
		position += 1
		data.append(byte)
		if byte == 0xB1:
			break

		if byte in (0xB2, 0xB3):
			data.write_int(len(data), ROM.read_int(position) - start_pointer)
			position += 4
		elif byte == 0xBD:
			instrument = ROM[position]
			position += 1
			translated = instrument_map.translate(instrument)
			# translate() returns 0 only for percussion. Selecting any other
			# instrument must leave percussion mode; otherwise later melodic
			# notes would have their keys rewritten as percussion ids.
			percussion = instrument if translated == 0 else None
			data.append(translated)
		elif byte in (0xBB, 0xBC, 0xBE, 0xBF, 0xC0, 0xC1):
			# These commands take a data byte that must not be processed.
			data.append(ROM[position])
			position += 1
		elif percussion is not None and byte < 0x80:
			data[-1] = instrument_map.translate_percussion(percussion, byte)
			# There might be a volume marker, and then a 'gate' byte
			# For now, assuming that any subsequent low-value bytes are extra data
			# that should be passed as-is - even though previous experimentation suggested
			# that these bytes could be used to specify a chord...
			while ROM[position] < 0x80: # Volume marker
				data.append(ROM[position])
				position += 1

	return data


def as_extended(value):
	"""Split a non-negative integer into the (exponent, mantissa) fields of an
	80-bit extended-precision float, as packed into the AIFF COMM chunk.

	The exponent includes the 16383 bias. The 64-bit mantissa has an explicit
	leading 1 bit. Zero encodes as (0, 0). Raises ValueError for negative
	values, which previously looped forever."""
	if value < 0:
		raise ValueError('as_extended needs a non-negative value, got %r' % (value,))
	if value == 0:
		return (0, 0)

	bits = 0
	remaining = value
	while remaining:
		remaining >>= 1
		bits += 1
	# Now (1 << (bits - 1)) <= value < (1 << bits).
	mantissa = value << (64 - bits) # move the highest set bit to bit 63
	exponent = (bits - 1) + 16383 # unbiased exponent is bits - 1
	return (exponent, mantissa)


def deref(ptr, ROM):
	"""Read the 32-bit GBA pointer stored at offset *ptr* of *ROM* and return
	it as a ROM offset (None when the stored pointer is null)."""
	stored = ROM.read_int(ptr)
	return as_offset(stored)


def header(frequency, count, fmt):
	"""Return the file header for *count* bytes of 8-bit, single-channel PCM
	sampled at *frequency* Hz, in output format *fmt*.

	Only 'aiff' is implemented. Any other format raises PEBKAC; previously it
	returned None, which broke write_sample with an unrelated TypeError."""
	if fmt != 'aiff':
		# TODO: 'wav' output is not implemented yet.
		raise PEBKAC('Unsupported sample format: %s' % fmt)

	exponent, mantissa = as_extended(frequency)
	# FORM size: 'AIFF' (4) + COMM (8 + 18) + INST (8 + 20) + MARK (8 + 2)
	# + APPL (8 + 6) + SSND (8 + 8 + count) = count + 98.
	# NOTE(review): AIFF expects a pad byte after odd-length chunk data; an
	# odd count is not padded here -- confirm whether target players care.
	return pack(
		'>4sl8slhLhHQ' + '4sl6b7h' + '4slH' + '4sl4s2b' + '4sl2L',
		'FORM', count + 98, 'AIFFCOMM', 18, 1, count, 8, exponent, mantissa,
		'INST', 20, 60, 0, 60, 127, 0, 127, 0, 0, 0, 0, 0, 0, 0,
		'MARK', 2, 0,
		'APPL', 6, 'auFM', 0, 0,
		'SSND', count + 8, 0, 0
	)


def body(pcm, fmt):
	"""Return the sample payload to write for *fmt*: the raw PCM for 'aiff'
	and 'wav', None for any other format."""
	supported = ('aiff', 'wav')
	return pcm if fmt in supported else None


def write_sample(name, frequency, data, fmt):
	"""Write *data* as a sample file named '<name>.<fmt>': the header for
	*fmt* followed by the body."""
	count = len(data)
	# open() is the documented way to create file objects; file() is deprecated.
	with open('.'.join((name, fmt)), 'wb') as out:
		out.write(header(frequency, count, fmt))
		out.write(body(data, fmt))


def write_samples(ROM, folder, ptrs, fmt):
	"""Dump the samples whose headers start at ROM offsets *ptrs* into
	*folder*, one '<index> - <offset>.<fmt>' file each, printing a dot per
	sample.

	The frequency is read from header offset 4 (stored scaled by 1024). The
	data length is taken as the value at header offset 12, plus one. The
	data starts at header offset 16.

	Raises PEBKAC if *folder* cannot be created and does not already exist."""
	import os
	try:
		os.makedirs(folder)
	except OSError as error:
		# An existing folder is fine; any other failure (permissions, a file
		# in the way, ...) is reported instead of being silently swallowed.
		if not os.path.isdir(folder):
			raise PEBKAC(str(error))
	for i, ptr in enumerate(ptrs):
		frequency = ROM.read_int(ptr + 4) >> 10
		count = ROM.read_int(ptr + 12) + 1
		data = ROM[ptr + 16 : ptr + 16 + count]
		display('.', '')
		write_sample(
			os.path.join(folder, '%d - %06X' % (i, ptr)), frequency, data, fmt
		)


def do_dump(ROM, table_start, table_end, folder, fmt):
	"""Collect every sample referenced by the songs in a song table and dump
	them with write_samples.

	The table between *table_start* and *table_end* is read as 8-byte entries
	whose first word points to a song header. The word at offset 4 of each
	song header points to its instrument map, and 128 of its 12-byte entries
	are scanned. Entries whose type byte is 0x00 or 0x08 contribute the
	sample their second word points to."""
	sample_ptrs = set()
	# Sample offsets must lie inside the ROM; the 16 MiB cap is the original
	# heuristic for rejecting garbage pointers.
	limit = min(len(ROM), 0x1000000)
	for ptr in xrange(table_start, table_end, 8):
		song_ptr = deref(ptr, ROM)
		# Skip null entries and dummy entries pointing at the end of the table.
		if song_ptr is None or song_ptr == table_end:
			continue
		# Skip the track count info, to the instrument map pointer.
		imap_base = deref(song_ptr + 4, ROM)
		if imap_base is None:
			continue
		for imap_ptr in xrange(imap_base, imap_base + 128 * 12, 12):
			# The type is the entry's first byte. Reading it through deref would
			# crash on an all-zero entry, because deref maps 0 to None.
			itype = ROM[imap_ptr]
			if itype not in (0, 8):
				continue # e.g. 0x40 / 0x80 map types; nothing to dump directly
			sample_ptr = deref(imap_ptr + 4, ROM)
			if sample_ptr is not None and 0 <= sample_ptr < limit:
				sample_ptrs.add(sample_ptr)
	write_samples(ROM, folder, sorted(sample_ptrs), fmt)
